# ===============================================
# RX 480 Prismatic Throughput Analyzer
# Auto-hunts optimal recursion depths
# ===============================================
import pyopencl as cl
import numpy as np
import time
from itertools import combinations

# ---------------------------
# Detect GPU
# ---------------------------
# Walk the OpenCL platforms in enumeration order and take the first device
# whose type bitfield has the GPU bit set; None means no GPU was found.
device = next(
    (
        candidate
        for cl_platform in cl.get_platforms()
        for candidate in cl_platform.get_devices()
        if candidate.type & cl.device_type.GPU
    ),
    None,
)

# Nothing below can run without a GPU, so stop immediately.
if device is None:
    raise RuntimeError("No GPU found.")

# Single-device context and one command queue used by everything that follows.
ctx = cl.Context([device])
queue = cl.CommandQueue(ctx)

# Summarize the selected device (memory size is reported in whole MiB).
print("Using device:", device.name)
print("Global Memory (MB):", device.global_mem_size // (1024 * 1024))
print("Compute Units:", device.max_compute_units)
print("Max Clock (MHz):", device.max_clock_frequency)

# ---------------------------
# Kernel: Prismatic Base(∞) recursion
# ---------------------------
# OpenCL C source of the benchmark kernel.
#
# Each work-item loads one element of `data`, applies `depth_max` iterations
# of  x = sqrt(x * phi**(d/16) + 0.5) * 1.0001  and stores the result back to
# the same slot, so the buffer is updated in place.
#
# The kernel works in `double`. OpenCL 1.0/1.1 compilers reject `double`
# unless the cl_khr_fp64 extension is enabled explicitly, so the pragma is
# included for portability; on fp64-capable OpenCL 1.2+ devices it is
# accepted and has no further effect.
kernel_code = """
#pragma OPENCL EXTENSION cl_khr_fp64 : enable

__kernel void recurse_prismatic_inf(
    __global double *data,
    const int depth_max,
    const double phi)
{
    int gid = get_global_id(0);
    double x = data[gid];

    for(int d=0; d<depth_max; d++){
        double factor = pow(phi, d/16.0);  // fractional exponent for prismatic braid
        x = sqrt(x * factor + 0.5) * 1.0001;
    }

    data[gid] = x;
}
"""
# Compile once for the selected context.
program = cl.Program(ctx, kernel_code).build()
# Create the kernel object once so it can be reused for every launch.
kernel = cl.Kernel(program, "recurse_prismatic_inf")

# ---------------------------
# Benchmark params
# ---------------------------
# Problem size: 2**20 (1,048,576) elements, one work-item per element.
N = 1 << 20
# Base handed to the kernel's pow() call.
phi = 1.6180339887
# Inclusive range of recursion depths to benchmark.
depth_start = 7
depth_end = 50

# Random float64 host data that seeds the device buffer.
data = np.random.rand(N).astype(np.float64)

# Read/write device buffer initialised with a copy of the host array.
buf = cl.Buffer(
    ctx,
    cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR,
    hostbuf=data,
)

# Filled by the benchmark with (depth, fps, gflops) tuples.
results = []

# ---------------------------
# Benchmark each depth
# ---------------------------
# The buffer and phi arguments are identical for every depth. Kernel
# arguments persist across launches, so bind them once up front and only
# update the depth argument inside the loop.
kernel.set_arg(0, buf)
kernel.set_arg(2, np.float64(phi))

# Size of the data set in MiB; it does not depend on the depth.
vram_mb = data.nbytes / 1024**2

# NOTE: the kernel writes its result back into `buf`, so every launch
# (warmup and timed) starts from the values left by the previous launch.
for depth in range(depth_start, depth_end + 1):
    kernel.set_arg(1, np.int32(depth))

    # Warmup launch, not timed.
    evt = cl.enqueue_nd_range_kernel(queue, kernel, (N,), None)
    evt.wait()

    # Timed run. perf_counter() is a monotonic, high-resolution clock;
    # time.time() can be too coarse for short kernel runs and can jump if
    # the system clock is adjusted.
    t0 = time.perf_counter()
    evt = cl.enqueue_nd_range_kernel(queue, kernel, (N,), None)
    evt.wait()
    dt = max(time.perf_counter() - t0, 1e-9)  # clamp to keep the divisions safe
    fps = 1.0 / dt
    # Throughput estimate that counts ONE operation per loop iteration per
    # work-item (N * depth), expressed in units of 1e9 per second.
    flops = (N * depth) / dt / 1e9

    print(f"Depth {depth:2d} | N={N:,} | VRAM={vram_mb:.1f} MB | {fps:.2f} FPS | {flops:.2f} GFLOPs")
    results.append((depth, fps, flops))

# ---------------------------
# Hunt for highest throughput combinations
# ---------------------------
# Rank all measurements by their GFLOPs figure (third tuple field), best
# first, then keep at most the ten leading entries.
ranked_by_gflops = sorted(results, key=lambda record: record[2], reverse=True)
top_depths = ranked_by_gflops[:10]

print("\nTop individual depths by GFLOPs:")
for d, fps, flops in top_depths:
    summary = f"Depth {d} -> {flops:.2f} GFLOPs at {fps:.2f} FPS"
    print(summary)

# Optional: evaluate combinations of depths, scoring each combination by the
# sum of its members' GFLOPs and reporting the best one per combination size.
#
# The depth list and the depth -> GFLOPs lookup are built once, so scoring a
# combination costs one dict lookup per member instead of a scan of `results`.
# NOTE: because the score is a plain sum, the winning combination is made up
# of the individually highest-GFLOPs depths.
depths = [r[0] for r in results]
flops_by_depth = {r[0]: r[2] for r in results}

combo_sizes = [2, 3]  # pairs and triples
for sz in combo_sizes:
    best_combo = None
    best_combo_flops = 0
    for combo in combinations(depths, sz):
        combo_flops = sum(flops_by_depth[member] for member in combo)
        # Strict comparison: on ties the earliest combination is kept.
        if combo_flops > best_combo_flops:
            best_combo_flops = combo_flops
            best_combo = combo
    print(f"Best combo of size {sz}: {best_combo} -> {best_combo_flops:.2f} GFLOPs total")
